import pandas as pd
import torch
from sklearn.impute import SimpleImputer
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn import preprocessing
Problem Statement
The goal of this competition is to predict if a person has any of three medical conditions. You are being asked to predict if the person has one or more of any of the three medical conditions (Class 1), or none of the three medical conditions (Class 0). You will create a model trained on measurements of health characteristics. I came in top 17% of the teams who participated in the competition.
Installing Required Libraries
from fastai.tabular.all import *
# Show floats with two decimal places in pandas output.
pd.options.display.float_format = '{:.2f}'.format
# set_seed(42)
print(torch.backends.mps.is_available())
False
# Use the Apple-silicon GPU (Metal Performance Shaders) as torch's default
# device. NOTE(review): this assumes MPS is available — the check above
# printed False in this run, so tensor ops may fall back or fail; confirm.
device = torch.device("mps")
torch.set_default_device(device)
import os
# Detect whether this notebook is running inside a Kaggle kernel
# (the env var is set to e.g. 'Batch' on Kaggle, empty locally).
isKaggleEnv = os.environ.get('KAGGLE_KERNEL_RUN_TYPE', '')
isKaggleEnv
'Batch'
# install fastkaggle if not available
if not isKaggleEnv:
try: import fastkaggle
except ModuleNotFoundError:
!pip install -Uq fastkaggle
from fastkaggle import *
Reading Data
# Show every column and row when displaying dataframes (57 feature columns).
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
# Locate the competition data: already mounted under ../input on Kaggle,
# otherwise download it locally via fastkaggle's setup_comp.
if isKaggleEnv:
    path = Path('../input/icr-identify-age-related-conditions')
else:
    comp = 'icr-identify-age-related-conditions'
    path = setup_comp(comp, install='fastai "timm>=0.6.2.dev0"')
path.ls()
(#4) [Path('../input/icr-identify-age-related-conditions/sample_submission.csv'),Path('../input/icr-identify-age-related-conditions/greeks.csv'),Path('../input/icr-identify-age-related-conditions/train.csv'),Path('../input/icr-identify-age-related-conditions/test.csv')]
# Load the supplementary metadata, training set and test set.
greek_df = pd.read_csv(path/"greeks.csv")
df = pd.read_csv(path/"train.csv")
test_df = pd.read_csv(path/"test.csv")
df.head(10)
Id | AB | AF | AH | AM | AR | AX | AY | AZ | BC | BD | BN | BP | BQ | BR | BZ | CB | CC | CD | CF | CH | CL | CR | CS | CU | CW | DA | DE | DF | DH | DI | DL | DN | DU | DV | DY | EB | EE | EG | EH | EJ | EL | EP | EU | FC | FD | FE | FI | FL | FR | FS | GB | GE | GF | GH | GI | GL | Class | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 000ff2bfdfe9 | 0.21 | 3109.03 | 85.20 | 22.39 | 8.14 | 0.70 | 0.03 | 9.81 | 5.56 | 4126.59 | 22.60 | 175.64 | 152.71 | 823.93 | 257.43 | 47.22 | 0.56 | 23.39 | 4.85 | 0.02 | 1.05 | 0.07 | 13.78 | 1.30 | 36.21 | 69.08 | 295.57 | 0.24 | 0.28 | 89.25 | 84.32 | 29.66 | 5.31 | 1.74 | 23.19 | 7.29 | 1.99 | 1433.17 | 0.95 | B | 30.88 | 78.53 | 3.83 | 13.39 | 10.27 | 9028.29 | 3.58 | 7.30 | 1.74 | 0.09 | 11.34 | 72.61 | 2003.81 | 22.14 | 69.83 | 0.12 | 1 |
1 | 007255e47698 | 0.15 | 978.76 | 85.20 | 36.97 | 8.14 | 3.63 | 0.03 | 13.52 | 1.23 | 5496.93 | 19.42 | 155.87 | 14.75 | 51.22 | 257.43 | 30.28 | 0.48 | 50.63 | 6.09 | 0.03 | 1.11 | 1.12 | 28.31 | 1.36 | 37.48 | 70.80 | 178.55 | 0.24 | 0.36 | 110.58 | 75.75 | 37.53 | 0.01 | 1.74 | 17.22 | 4.93 | 0.86 | 1111.29 | 0.00 | A | 109.13 | 95.42 | 52.26 | 17.18 | 0.30 | 6785.00 | 10.36 | 0.17 | 0.50 | 0.57 | 9.29 | 72.61 | 27981.56 | 29.14 | 32.13 | 21.98 | 0 |
2 | 013f2bd269f5 | 0.47 | 2635.11 | 85.20 | 32.36 | 8.14 | 6.73 | 0.03 | 12.82 | 1.23 | 5135.78 | 26.48 | 128.99 | 219.32 | 482.14 | 257.43 | 32.56 | 0.50 | 85.96 | 5.38 | 0.04 | 1.05 | 0.70 | 39.36 | 1.01 | 21.46 | 70.82 | 321.43 | 0.24 | 0.21 | 120.06 | 65.47 | 28.05 | 1.29 | 1.74 | 36.86 | 7.81 | 8.15 | 1494.08 | 0.38 | B | 109.13 | 78.53 | 5.39 | 224.21 | 8.75 | 8338.91 | 11.63 | 7.71 | 0.98 | 1.20 | 37.08 | 88.61 | 13676.96 | 28.02 | 35.19 | 0.20 | 0 |
3 | 043ac50845d5 | 0.25 | 3819.65 | 120.20 | 77.11 | 8.14 | 3.69 | 0.03 | 11.05 | 1.23 | 4169.68 | 23.66 | 237.28 | 11.05 | 661.52 | 257.43 | 15.20 | 0.72 | 88.16 | 2.35 | 0.03 | 1.40 | 0.64 | 41.12 | 0.72 | 21.53 | 47.28 | 196.61 | 0.24 | 0.29 | 139.82 | 71.57 | 24.35 | 2.66 | 1.74 | 52.00 | 7.39 | 3.81 | 15691.55 | 0.61 | B | 31.67 | 78.53 | 31.32 | 59.30 | 7.88 | 10965.77 | 14.85 | 6.12 | 0.50 | 0.28 | 18.53 | 82.42 | 2094.26 | 39.95 | 90.49 | 0.16 | 0 |
4 | 044fb8a146ec | 0.38 | 3733.05 | 85.20 | 14.10 | 8.14 | 3.94 | 0.05 | 3.40 | 102.15 | 5728.73 | 24.01 | 324.55 | 149.72 | 6074.86 | 257.43 | 82.21 | 0.54 | 72.64 | 30.54 | 0.03 | 1.05 | 0.69 | 31.72 | 0.83 | 34.42 | 74.07 | 200.18 | 0.24 | 0.21 | 97.92 | 52.84 | 26.02 | 1.14 | 1.74 | 9.06 | 7.35 | 3.49 | 1403.66 | 0.16 | B | 109.13 | 91.99 | 51.14 | 29.10 | 4.27 | 16198.05 | 13.67 | 8.15 | 48.50 | 0.12 | 16.41 | 146.11 | 8524.37 | 45.38 | 36.26 | 0.10 | 1 |
5 | 04517a3c90bd | 0.21 | 2615.81 | 85.20 | 8.54 | 8.14 | 4.01 | 0.03 | 12.55 | 1.23 | 5237.54 | 10.24 | 148.49 | 16.53 | 642.33 | 257.43 | 18.38 | 0.64 | 80.67 | 14.69 | 0.02 | 1.05 | 0.86 | 32.46 | 1.39 | 7.03 | 55.22 | 135.49 | 0.24 | 0.48 | 135.32 | 81.46 | 31.73 | 0.01 | 1.74 | 16.77 | 4.93 | 2.39 | 866.38 | 0.00 | A | 109.13 | 78.53 | 3.83 | 23.30 | 0.30 | 8517.28 | 10.98 | 0.17 | 0.50 | 1.16 | 21.92 | 72.61 | 24177.60 | 28.53 | 82.53 | 21.98 | 0 |
6 | 049232ca8356 | 0.35 | 1733.65 | 85.20 | 8.38 | 15.31 | 1.91 | 0.03 | 6.55 | 1.23 | 5710.46 | 17.66 | 143.65 | 344.64 | 719.73 | 257.43 | 38.46 | 0.95 | 78.30 | 13.18 | 0.03 | 1.05 | 0.61 | 13.78 | 2.79 | 21.88 | 19.22 | 107.91 | 1.32 | 0.46 | 176.63 | 97.08 | 44.51 | 1.01 | 1.74 | 4.47 | 4.93 | 2.62 | 1793.61 | 0.10 | B | 13.21 | 78.53 | 26.30 | 48.27 | 1.46 | 3903.81 | 10.78 | 4.41 | 0.86 | 0.47 | 17.88 | 192.45 | 3332.47 | 34.17 | 100.09 | 0.07 | 0 |
7 | 057287f2da6d | 0.27 | 966.45 | 85.20 | 21.17 | 8.14 | 4.99 | 0.03 | 9.41 | 1.23 | 5040.78 | 20.83 | 170.05 | 6.20 | 701.02 | 257.43 | 12.87 | 0.77 | 71.54 | 24.91 | 0.03 | 1.05 | 1.11 | 41.93 | 1.19 | 42.11 | 63.22 | 326.23 | 0.24 | 0.33 | 83.77 | 73.99 | 19.91 | 2.12 | 1.74 | 15.87 | 8.35 | 3.28 | 767.72 | 0.29 | B | 15.09 | 104.99 | 5.10 | 37.56 | 4.52 | 18090.35 | 10.34 | 6.59 | 0.50 | 0.28 | 18.45 | 109.69 | 21371.76 | 35.21 | 31.42 | 0.09 | 0 |
8 | 0594b00fb30a | 0.35 | 3238.44 | 85.20 | 28.89 | 8.14 | 4.02 | 0.03 | 8.24 | 3.63 | 6569.37 | 20.48 | 135.88 | NaN | 601.80 | 257.43 | 116.10 | 0.86 | 93.23 | 14.57 | 0.03 | 1.05 | 1.05 | 29.91 | 1.47 | 43.02 | 76.77 | 231.13 | 0.24 | 0.33 | 131.35 | 98.17 | 29.47 | 0.61 | 1.74 | 7.20 | 10.77 | 1.34 | 3004.93 | 0.07 | B | NaN | 78.53 | 56.61 | 35.30 | 1.39 | 3380.03 | 11.45 | 4.76 | 1.18 | 0.07 | 17.25 | 147.22 | 4589.61 | 29.77 | 54.68 | 0.07 | 0 |
9 | 05f2bc0155cd | 0.32 | 5188.68 | 85.20 | 12.97 | 8.14 | 4.59 | 0.03 | 10.69 | 1.23 | 4951.70 | 21.89 | 202.17 | 107.28 | 906.61 | 257.43 | 41.65 | 0.72 | 58.80 | 29.10 | 0.03 | 1.05 | 0.71 | 36.79 | 1.61 | 36.72 | 31.18 | 403.25 | 0.24 | 0.35 | 129.18 | 106.52 | 22.32 | 0.01 | 1.74 | 16.71 | 9.10 | 3.66 | 1637.13 | 0.00 | A | 91.47 | 99.82 | 3.83 | 29.78 | 0.30 | 3142.39 | 12.33 | 0.17 | 1.57 | 0.32 | 24.52 | 98.93 | 5563.13 | 21.99 | 33.30 | 21.98 | 0 |
df.shape
(617, 58)
Let’s look at the data
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 617 entries, 0 to 616
Data columns (total 58 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Id 617 non-null object
1 AB 617 non-null float64
2 AF 617 non-null float64
3 AH 617 non-null float64
4 AM 617 non-null float64
5 AR 617 non-null float64
6 AX 617 non-null float64
7 AY 617 non-null float64
8 AZ 617 non-null float64
9 BC 617 non-null float64
10 BD 617 non-null float64
11 BN 617 non-null float64
12 BP 617 non-null float64
13 BQ 557 non-null float64
14 BR 617 non-null float64
15 BZ 617 non-null float64
16 CB 615 non-null float64
17 CC 614 non-null float64
18 CD 617 non-null float64
19 CF 617 non-null float64
20 CH 617 non-null float64
21 CL 617 non-null float64
22 CR 617 non-null float64
23 CS 617 non-null float64
24 CU 617 non-null float64
25 CW 617 non-null float64
26 DA 617 non-null float64
27 DE 617 non-null float64
28 DF 617 non-null float64
29 DH 617 non-null float64
30 DI 617 non-null float64
31 DL 617 non-null float64
32 DN 617 non-null float64
33 DU 616 non-null float64
34 DV 617 non-null float64
35 DY 617 non-null float64
36 EB 617 non-null float64
37 EE 617 non-null float64
38 EG 617 non-null float64
39 EH 617 non-null float64
40 EJ 617 non-null object
41 EL 557 non-null float64
42 EP 617 non-null float64
43 EU 617 non-null float64
44 FC 616 non-null float64
45 FD 617 non-null float64
46 FE 617 non-null float64
47 FI 617 non-null float64
48 FL 616 non-null float64
49 FR 617 non-null float64
50 FS 615 non-null float64
51 GB 617 non-null float64
52 GE 617 non-null float64
53 GF 617 non-null float64
54 GH 617 non-null float64
55 GI 617 non-null float64
56 GL 616 non-null float64
57 Class 617 non-null int64
dtypes: float64(55), int64(1), object(2)
memory usage: 279.7+ KB
# Statistical summary of the data
df.describe().T
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
AB | 617.00 | 0.48 | 0.47 | 0.08 | 0.25 | 0.35 | 0.56 | 6.16 |
AF | 617.00 | 3502.01 | 2300.32 | 192.59 | 2197.35 | 3120.32 | 4361.64 | 28688.19 |
AH | 617.00 | 118.62 | 127.84 | 85.20 | 85.20 | 85.20 | 113.74 | 1910.12 |
AM | 617.00 | 38.97 | 69.73 | 3.18 | 12.27 | 20.53 | 39.14 | 630.52 |
AR | 617.00 | 10.13 | 10.52 | 8.14 | 8.14 | 8.14 | 8.14 | 178.94 |
AX | 617.00 | 5.55 | 2.55 | 0.70 | 4.13 | 5.03 | 6.43 | 38.27 |
AY | 617.00 | 0.06 | 0.42 | 0.03 | 0.03 | 0.03 | 0.04 | 10.32 |
AZ | 617.00 | 10.57 | 4.35 | 3.40 | 8.13 | 10.46 | 12.97 | 38.97 |
BC | 617.00 | 8.05 | 65.17 | 1.23 | 1.23 | 1.23 | 5.08 | 1463.69 |
BD | 617.00 | 5350.39 | 3021.33 | 1693.62 | 4155.70 | 4997.96 | 6035.89 | 53060.60 |
BN | 617.00 | 21.42 | 3.48 | 9.89 | 19.42 | 21.19 | 23.66 | 29.31 |
BP | 617.00 | 231.32 | 183.99 | 72.95 | 156.85 | 193.91 | 247.80 | 2447.81 |
BQ | 557.00 | 98.33 | 96.48 | 1.33 | 27.83 | 61.64 | 134.01 | 344.64 |
BR | 617.00 | 1218.13 | 7575.29 | 51.22 | 424.99 | 627.42 | 975.65 | 179250.25 |
BZ | 617.00 | 550.63 | 2076.37 | 257.43 | 257.43 | 257.43 | 257.43 | 50092.46 |
CB | 615.00 | 77.10 | 159.05 | 12.50 | 23.32 | 42.55 | 77.31 | 2271.44 |
CC | 614.00 | 0.69 | 0.26 | 0.18 | 0.56 | 0.66 | 0.77 | 4.10 |
CD | 617.00 | 90.25 | 51.59 | 23.39 | 64.72 | 79.82 | 99.81 | 633.53 |
CF | 617.00 | 11.24 | 13.57 | 0.51 | 5.07 | 9.12 | 13.57 | 200.97 |
CH | 617.00 | 0.03 | 0.01 | 0.00 | 0.02 | 0.03 | 0.03 | 0.22 |
CL | 617.00 | 1.40 | 1.92 | 1.05 | 1.05 | 1.05 | 1.23 | 31.69 |
CR | 617.00 | 0.74 | 0.28 | 0.07 | 0.59 | 0.73 | 0.86 | 3.04 |
CS | 617.00 | 36.92 | 17.27 | 13.78 | 29.78 | 34.84 | 40.53 | 267.94 |
CU | 617.00 | 1.38 | 0.54 | 0.14 | 1.07 | 1.35 | 1.66 | 4.95 |
CW | 617.00 | 27.17 | 14.65 | 7.03 | 7.03 | 36.02 | 37.94 | 64.52 |
DA | 617.00 | 51.13 | 21.21 | 6.91 | 37.94 | 49.18 | 61.41 | 210.33 |
DE | 617.00 | 401.90 | 317.75 | 36.00 | 188.82 | 307.51 | 507.90 | 2103.41 |
DF | 617.00 | 0.63 | 1.91 | 0.24 | 0.24 | 0.24 | 0.24 | 37.90 |
DH | 617.00 | 0.37 | 0.11 | 0.04 | 0.30 | 0.36 | 0.43 | 1.06 |
DI | 617.00 | 146.97 | 86.08 | 60.23 | 102.70 | 130.05 | 165.84 | 1049.17 |
DL | 617.00 | 94.80 | 28.24 | 10.35 | 78.23 | 96.26 | 110.64 | 326.24 |
DN | 617.00 | 26.37 | 8.04 | 6.34 | 20.89 | 25.25 | 30.54 | 62.81 |
DU | 616.00 | 1.80 | 9.03 | 0.01 | 0.01 | 0.25 | 1.06 | 161.36 |
DV | 617.00 | 1.92 | 1.48 | 1.74 | 1.74 | 1.74 | 1.74 | 25.19 |
DY | 617.00 | 26.39 | 18.12 | 0.80 | 14.72 | 21.64 | 34.06 | 152.36 |
EB | 617.00 | 9.07 | 6.20 | 4.93 | 5.97 | 8.15 | 10.50 | 94.96 |
EE | 617.00 | 3.06 | 2.06 | 0.29 | 1.65 | 2.62 | 3.91 | 18.32 |
EG | 617.00 | 1731.25 | 1790.23 | 185.59 | 1111.16 | 1493.82 | 1905.70 | 30243.76 |
EH | 617.00 | 0.31 | 1.85 | 0.00 | 0.00 | 0.09 | 0.24 | 42.57 |
EL | 557.00 | 69.58 | 38.56 | 5.39 | 30.93 | 71.95 | 109.13 | 109.13 |
EP | 617.00 | 105.06 | 68.45 | 78.53 | 78.53 | 78.53 | 112.77 | 1063.59 |
EU | 617.00 | 69.12 | 390.19 | 3.83 | 4.32 | 22.64 | 49.09 | 6501.26 |
FC | 616.00 | 71.34 | 165.55 | 7.53 | 25.82 | 36.39 | 56.71 | 3030.66 |
FD | 617.00 | 6.93 | 64.75 | 0.30 | 0.30 | 1.87 | 4.88 | 1578.65 |
FE | 617.00 | 10306.81 | 11331.29 | 1563.14 | 5164.67 | 7345.14 | 10647.95 | 143224.68 |
FI | 617.00 | 10.11 | 2.93 | 3.58 | 8.52 | 9.95 | 11.52 | 35.85 |
FL | 616.00 | 5.43 | 11.50 | 0.17 | 0.17 | 3.03 | 6.24 | 137.93 |
FR | 617.00 | 3.53 | 50.18 | 0.50 | 0.50 | 1.13 | 1.51 | 1244.23 |
FS | 615.00 | 0.42 | 1.31 | 0.07 | 0.07 | 0.25 | 0.54 | 31.37 |
GB | 617.00 | 20.72 | 9.99 | 4.10 | 14.04 | 18.77 | 25.61 | 135.78 |
GE | 617.00 | 131.71 | 144.18 | 72.61 | 72.61 | 72.61 | 127.59 | 1497.35 |
GF | 617.00 | 14679.60 | 19352.96 | 13.04 | 2798.99 | 7838.27 | 19035.71 | 143790.07 |
GH | 617.00 | 31.49 | 9.86 | 9.43 | 25.03 | 30.61 | 36.86 | 81.21 |
GI | 617.00 | 50.58 | 36.27 | 0.90 | 23.01 | 41.01 | 67.93 | 191.19 |
GL | 616.00 | 8.53 | 10.33 | 0.00 | 0.12 | 0.34 | 21.98 | 21.98 |
Class | 617.00 | 0.18 | 0.38 | 0.00 | 0.00 | 0.00 | 0.00 | 1.00 |
# Checking for Null values, and sorting in descending order
df.isnull().sum().sort_values(ascending=False)
EL 60
BQ 60
CC 3
FS 2
CB 2
FL 1
FC 1
DU 1
GL 1
EE 0
EB 0
EU 0
DY 0
EH 0
EJ 0
DV 0
EP 0
EG 0
Id 0
DL 0
FD 0
FE 0
FI 0
FR 0
GB 0
GE 0
GF 0
GH 0
GI 0
DN 0
DH 0
DI 0
BR 0
AF 0
AH 0
AM 0
AR 0
AX 0
AY 0
AZ 0
BC 0
BD 0
BN 0
BP 0
BZ 0
AB 0
CD 0
CF 0
CH 0
CL 0
CR 0
CS 0
CU 0
CW 0
DA 0
DE 0
DF 0
Class 0
dtype: int64
# Percentage of missing values per column, and the names of the
# columns that have any.
missing_values = df.isnull().sum() * (100 / df.shape[0])
features_with_missing_values = missing_values[missing_values > 0].index.values
print(features_with_missing_values)
['BQ' 'CB' 'CC' 'DU' 'EL' 'FC' 'FL' 'FS' 'GL']
# Checking for duplicates
df.duplicated().sum()
0
# Checking number of 0s Vs 1 for class variable
df.Class.value_counts()
0 509
1 108
Name: Class, dtype: int64
=True) df.Class.value_counts(normalize
0 0.82
1 0.18
Name: Class, dtype: float64
Basic Details
- Contains null values - Need to replace those
- There are no duplicates
- out of 58 columns, 56 are numerical,
id
andEJ
columns are object type - There’s imbalance in the target variable i.e., ~83% rows are
0
and ~17% are1
Data Cleaning
OneHotEncoding
Note: I am performing OHE, however, I’m not using this as part of the current revision.
Column ‘EJ’ is categorical, as mentioned in the metadata (present on Kaggle). It contains string values ‘A’ and ‘B’. Hence, we need to convert the strings to numerical values in order to perform modelling.
'EJ'].value_counts() df[
B 395
A 222
Name: EJ, dtype: int64
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer

# One-hot encode the categorical 'EJ' column; remainder='drop' discards
# every other column, so the result contains only the two EJ indicator
# columns.
categorical_features = ["EJ"]
one_hot = OneHotEncoder(sparse_output=False)
transformer = ColumnTransformer(
    [("one_hot", one_hot, categorical_features)],
    remainder='drop',
)
# Ask sklearn to return pandas DataFrames instead of numpy arrays.
transformer.set_output(transform="pandas")
transformed_X = transformer.fit_transform(df)
# put in data frame for viewing
transformed_X[:1], pd.DataFrame(transformed_X).head()
( one_hot__EJ_A one_hot__EJ_B
0 0.00 1.00,
one_hot__EJ_A one_hot__EJ_B
0 0.00 1.00
1 1.00 0.00
2 0.00 1.00
3 0.00 1.00
4 0.00 1.00)
transformer.output_indices_
{'one_hot': slice(0, 2, None), 'remainder': slice(0, 0, None)}
# Wrap the encoded output in a DataFrame for inspection.
transformed_X_df = pd.DataFrame(transformed_X)
transformed_X_df.shape
(617, 2)
transformed_X_df.head()
one_hot__EJ_A | one_hot__EJ_B | |
---|---|---|
0 | 0.00 | 1.00 |
1 | 1.00 | 0.00 |
2 | 0.00 | 1.00 |
3 | 0.00 | 1.00 |
4 | 0.00 | 1.00 |
B
value of Column “EJ” is represented with 1, andA
value is represented as 0
Merging OneHotEncoded data into the original dataframe
transformed_X_df.shape, df.shape
((617, 2), (617, 58))
- We want to concat the two dataframes, and remove the original ‘EJ’ column. We would also want to remove any one of the ’one_hot__EJ_A’ or ’one_hot__EJ_B’ columns. This would help eliminate perfect multi collinearity that arises due to one-hot-encoding.
Handling Null Values
# Features having null values for target Class 1
features_with_missing_values_class_1 = df[df.Class == 1].isnull().sum().sort_values(ascending=False)
features_with_missing_values_class_1[features_with_missing_values_class_1 > 0]
EL 6
FC 1
FS 1
CC 1
dtype: int64
# Features having null values for target Class 0
features_with_missing_values_class_0 = df[df.Class == 0].isnull().sum().sort_values(ascending=False)
features_with_missing_values_class_0[features_with_missing_values_class_0 > 0]
BQ 60
EL 54
CC 2
CB 2
GL 1
DU 1
FS 1
FL 1
dtype: int64
'EL'].isnull() == True) & (df['BQ'].isnull() == False)] df[(df[
Id | AB | AF | AH | AM | AR | AX | AY | AZ | BC | BD | BN | BP | BQ | BR | BZ | CB | CC | CD | CF | CH | CL | CR | CS | CU | CW | DA | DE | DF | DH | DI | DL | DN | DU | DV | DY | EB | EE | EG | EH | EJ | EL | EP | EU | FC | FD | FE | FI | FL | FR | FS | GB | GE | GF | GH | GI | GL | Class | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
88 | 228524bde6a3 | 0.35 | 4088.29 | 85.20 | 44.40 | 8.14 | 4.53 | 0.03 | 3.40 | 1.23 | 5677.41 | 20.83 | 446.66 | 87.79 | 790.96 | 257.43 | 34.48 | 0.71 | 70.88 | 4.27 | 0.03 | 1.05 | 0.63 | 49.86 | 1.16 | 42.80 | 44.88 | 278.62 | 0.24 | 0.17 | 126.16 | 50.68 | 20.81 | 17.08 | 1.74 | 64.71 | 9.40 | 0.29 | 1091.68 | 0.72 | B | NaN | 78.53 | 20.94 | 98.09 | 12.30 | 3740.07 | 14.56 | 20.56 | 0.79 | 0.07 | 42.45 | 81.68 | 2307.94 | 35.08 | 33.73 | 0.03 | 1 |
178 | 4572189340c4 | 0.40 | 2306.40 | 85.20 | 75.25 | 8.14 | 6.28 | 0.03 | 3.40 | 1.23 | 3872.64 | 22.25 | 206.92 | 344.64 | 450.02 | 257.43 | 768.90 | 0.58 | 139.79 | 30.71 | 0.04 | 1.05 | 0.91 | 27.83 | 1.34 | 36.16 | 36.34 | 136.93 | 0.24 | 0.25 | 139.01 | 67.78 | 31.23 | 0.74 | 1.74 | 12.09 | 10.14 | 2.64 | 2023.47 | 0.28 | B | NaN | 96.27 | 36.66 | 14.13 | 6.05 | 4688.72 | 10.22 | 5.44 | 1.51 | 0.07 | 10.19 | 78.46 | 10032.18 | 19.85 | 41.76 | 0.26 | 0 |
304 | 79b44ed25c29 | 0.95 | 6192.62 | 99.86 | 29.18 | 8.14 | 3.63 | 0.03 | 7.40 | 7.92 | 5663.43 | 23.66 | 203.14 | 55.83 | 2083.03 | 738.13 | 19.14 | 0.28 | 100.55 | 4.19 | 0.04 | 1.52 | 0.45 | 31.88 | 1.61 | 38.94 | 24.14 | 66.01 | 0.38 | 0.20 | 178.69 | 93.20 | 34.99 | 0.78 | 1.74 | 17.19 | 11.16 | 1.33 | 3114.19 | 0.15 | B | NaN | 102.75 | 11.28 | 41.88 | 2.95 | 13797.34 | 7.99 | 5.01 | 1.04 | 1.67 | 17.26 | 72.61 | 3595.33 | 35.58 | 36.58 | 0.13 | 1 |
330 | 81015c6c3404 | 6.16 | 18964.47 | 210.56 | 85.39 | 8.14 | 17.98 | 0.03 | 8.87 | 6.77 | 7259.05 | 19.07 | 1027.41 | 344.64 | 740.68 | 1510.07 | 536.22 | NaN | 633.53 | 50.08 | 0.08 | 2.20 | 0.67 | 46.87 | 1.16 | 7.03 | 23.92 | 416.26 | 0.56 | 0.31 | 311.52 | 65.77 | 62.81 | 0.01 | 3.04 | 16.29 | 27.71 | 6.36 | 6845.91 | 0.00 | A | NaN | 110.71 | 132.90 | NaN | 0.30 | 5676.74 | 12.77 | 0.17 | 54.95 | NaN | 31.64 | 296.04 | 12261.84 | 49.59 | 39.46 | 21.98 | 1 |
471 | bbb1066a9afd | 0.26 | 1390.04 | 85.20 | 11.60 | 10.22 | 5.66 | 0.07 | 3.40 | 3.11 | 4906.74 | 24.36 | 381.86 | 9.49 | 198.31 | 595.89 | 70.85 | 0.29 | 99.79 | 2.42 | 0.03 | 1.05 | 0.58 | 38.69 | 1.46 | 7.03 | 57.63 | 267.66 | 0.24 | 0.34 | 140.92 | 75.61 | 27.28 | 5.64 | 1.74 | 22.84 | 8.95 | 2.11 | 1337.13 | 0.52 | B | NaN | 169.60 | 20.06 | 28.87 | 23.51 | 6302.76 | 7.80 | 31.25 | 1.14 | 0.07 | 22.97 | 72.61 | 4646.47 | 23.36 | 46.54 | 0.06 | 1 |
490 | c5f4dc4ae7fb | 0.47 | 4877.54 | 85.20 | 35.00 | 8.14 | 5.86 | 0.03 | 3.40 | 5.08 | 3877.10 | 20.13 | 266.60 | 344.64 | 478.57 | 257.43 | 67.49 | 0.41 | 146.60 | 4.64 | 0.03 | 1.05 | 0.89 | 37.02 | 1.62 | 36.07 | 67.35 | 118.32 | 0.24 | 0.33 | 99.51 | 71.43 | 42.16 | 5.63 | 1.74 | 19.13 | 5.89 | 3.39 | 1740.03 | 0.41 | B | NaN | 78.53 | 34.91 | 24.09 | 12.31 | 51958.46 | 13.49 | 20.99 | 1.08 | 0.54 | 23.11 | 72.61 | 2497.72 | 42.53 | 19.32 | 0.05 | 1 |
516 | d0ecfae80796 | 0.44 | 975.51 | 85.20 | 121.20 | 8.14 | 4.98 | 0.04 | 3.40 | 3.95 | 3932.87 | 21.89 | 239.22 | 178.91 | 1325.22 | 257.43 | 34.08 | 0.52 | 161.75 | 10.67 | 0.03 | 1.05 | 0.72 | 44.09 | 1.96 | 33.79 | 33.68 | 136.14 | 0.24 | 0.27 | 116.02 | 80.36 | 17.32 | 9.99 | 1.74 | 16.11 | 6.36 | 0.84 | 698.75 | 1.65 | B | NaN | 98.60 | 20.97 | 67.98 | 14.40 | 8356.04 | 9.22 | 8.08 | 2.75 | 0.99 | 31.14 | 72.61 | 4999.16 | 33.40 | 8.26 | 0.11 | 1 |
We can handle missing values by any of the below techniques:
Dropping rows or columns - This can lead to missing out of valuable information in the data. Most often, not a suggested approach.
Replacing missing values with mean or median, i.e., P50 (for continuous data) - Outliers can skew the mean, so replacing with the median is usually the safer option.
Replacing missing values with mode (for categorical) - This is only for categorical data, and may or may not work depending on the dataset you’re dealing with. This completely ignores the effect the features have on the target variable (i.e., feature importance and tree interpretation).
Replacing missing values using KNN model - The k nearest neighbor algorithm is often used to impute a missing value based on how closely it resembles the points in the training set. The non-null features are used to predict the features having null values
For Sake of simplicity, let’s replace the missing values with mode for now
# Most frequent value of every column; df.mode() can return several rows
# when there are ties, so take the first.
modes = df.mode().iloc[0]
modes
Id 000ff2bfdfe9
AB 0.26
AF 192.59
AH 85.20
AM 630.52
AR 8.14
AX 0.70
AY 0.03
AZ 3.40
BC 1.23
BD 1693.62
BN 20.48
BP 175.59
BQ 344.64
BR 51.22
BZ 257.43
CB 12.50
CC 0.46
CD 23.39
CF 0.51
CH 0.03
CL 1.05
CR 0.07
CS 13.78
CU 1.47
CW 7.03
DA 39.04
DE 183.07
DF 0.24
DH 0.37
DI 60.23
DL 10.35
DN 28.51
DU 0.01
DV 1.74
DY 0.80
EB 4.93
EE 0.29
EG 185.59
EH 0.00
EJ B
EL 109.13
EP 78.53
EU 3.83
FC 14.85
FD 0.30
FE 5088.92
FI 3.58
FL 0.17
FR 0.50
FS 0.07
GB 11.62
GE 72.61
GF 13.04
GH 9.43
GI 15.45
GL 21.98
Class 0.00
Name: 0, dtype: object
=True) df.fillna(modes, inplace
- Let’s now check for missing values in the dataframe
sum() df.isna().
Id 0
AB 0
AF 0
AH 0
AM 0
AR 0
AX 0
AY 0
AZ 0
BC 0
BD 0
BN 0
BP 0
BQ 0
BR 0
BZ 0
CB 0
CC 0
CD 0
CF 0
CH 0
CL 0
CR 0
CS 0
CU 0
CW 0
DA 0
DE 0
DF 0
DH 0
DI 0
DL 0
DN 0
DU 0
DV 0
DY 0
EB 0
EE 0
EG 0
EH 0
EJ 0
EL 0
EP 0
EU 0
FC 0
FD 0
FE 0
FI 0
FL 0
FR 0
FS 0
GB 0
GE 0
GF 0
GH 0
GI 0
GL 0
Class 0
dtype: int64
Handling categorical data type
Column ‘EJ’ is categorical, as mentioned in the metadata (present on Kaggle). It contains string values ‘A’, and ‘B’. Hence, let’s convert it’s dataType to categorical in pandas. We will also use cat.codes
to get the numeric codes for each category
# Convert 'EJ' to pandas categorical dtype; cat.codes maps A->0, B->1.
df['EJ'] = pd.Categorical(df.EJ)
df.EJ.head(), df.EJ.cat.codes.head()
(0 B
1 A
2 B
3 B
4 B
Name: EJ, dtype: category
Categories (2, object): ['A', 'B'],
0 1
1 0
2 1
3 1
4 1
dtype: int8)
Splitting Data into train and validation set.
Since this is not time-series specific data, it is safe to use Random splitter using scikit learn’s train_test_split to split into train and validation set. We get the column names as cols
, categorical column as cats
, dependent variable column as dep
, and any other irrelevant column such as id as irrelevant
.
# All column names of the training dataframe; peek at the first five.
cols = list(df)
cols[0:5]
['Id', 'AB', 'AF', 'AH', 'AM']
=["EJ"]
cats="Class"
dep="Id" irrelevant
[dep, irrelevant]
['Class', 'Id']
from sklearn.model_selection import train_test_split

# 75/25 random split into train and validation sets.
# NOTE(review): the split is unseeded and unstratified, so results vary
# between runs; given the ~82/18 class imbalance, consider
# random_state=... and stratify=df[dep] for reproducible, balanced folds.
trn_df, val_df = train_test_split(df, test_size=0.25)
We will convert the categorical columns into their numeric codes before proceeding further, since some of the models we’ll be building in a moment require that
# Replace the categorical columns with their integer codes in both splits.
trn_df[cats] = trn_df[cats].apply(lambda x: x.cat.codes)
val_df[cats] = val_df[cats].apply(lambda x: x.cat.codes)
def xs_y(df):
    """Split a dataframe into (features, target).

    Drops the target column (`dep`) and the Id column (`irrelevant`) from
    the features. Returns the target series when present; the test set has
    no 'Class' column, in which case the second element is None.
    """
    xs = df.drop(columns=[dep, irrelevant])
    return xs, (df[dep] if dep in df else None)
# Build feature/target pairs for both splits and sanity-check shapes.
trn_xs, trn_y = xs_y(trn_df)
val_xs, val_y = xs_y(val_df)
trn_xs.shape, trn_y.shape, val_xs.shape, val_y.shape
((462, 56), (462,), (155, 56), (155,))
2), trn_y.head(2), np.array(val_y), np.array(val_xs) trn_xs.head(
( AB AF AH AM AR AX AY AZ BC BD BN BP \
418 0.59 1915.51 85.20 35.78 8.14 5.63 0.03 13.38 5.11 6110.97 26.13 296.39
368 0.77 7205.59 85.20 25.15 8.14 4.11 0.03 3.40 11.73 6557.07 27.19 282.62
BQ BR BZ CB CC CD CF CH CL CR CS CU \
418 53.20 1282.50 257.43 108.79 0.78 104.23 8.81 0.04 1.09 0.47 48.18 1.66
368 121.36 1570.21 587.37 32.66 0.47 291.45 9.91 0.05 1.05 0.52 28.47 0.99
CW DA DE DF DH DI DL DN DU DV DY EB \
418 39.07 38.90 220.53 0.24 0.42 167.99 110.52 49.94 0.01 1.74 34.57 10.39
368 7.03 35.40 192.34 0.24 0.27 97.98 73.49 20.19 3.34 1.74 50.13 9.85
EE EG EH EJ EL EP EU FC FD FE FI FL \
418 3.29 2387.15 0.00 0 109.13 78.53 13.43 110.01 0.30 16268.12 9.68 0.17
368 3.02 1844.60 0.24 1 12.54 78.53 10.88 28.58 9.65 3073.15 10.25 21.77
FR FS GB GE GF GH GI GL
418 2.49 0.07 23.79 72.61 52081.01 27.90 50.31 21.98
368 3.73 0.07 25.39 72.61 1260.86 27.14 13.23 0.05 ,
418 0
368 1
Name: Class, dtype: int64,
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
0]),
array([[2.47834000e-01, 5.39291555e+03, 8.52001470e+01, ...,
2.43055720e+01, 5.35104600e+01, 2.19780000e+01],
[3.33294000e-01, 4.23949152e+03, 1.29137052e+02, ...,
2.91317090e+01, 4.90428960e+01, 9.52641510e-02],
[1.19644000e-01, 3.38098442e+03, 8.52001470e+01, ...,
3.40769180e+01, 4.74405400e+01, 1.70500000e-01],
...,
[1.53828000e-01, 5.56698010e+03, 1.73126211e+02, ...,
3.20489730e+01, 7.83071120e+01, 2.19780000e+01],
[5.08487000e-01, 1.63212907e+03, 8.52001470e+01, ...,
3.44825070e+01, 8.56476000e+00, 2.19780000e+01],
[3.97389000e-01, 2.30639794e+03, 8.52001470e+01, ...,
1.98478140e+01, 4.17564200e+01, 2.55364486e-01]]))
Building the RandomForest Model
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import log_loss

# Baseline model: 100-tree random forest; min_samples_leaf=5 regularizes
# the trees a little. Score with log loss on predicted probabilities,
# matching the competition's (balanced) log-loss metric family.
rf = RandomForestClassifier(100, min_samples_leaf=5)
rf.fit(trn_xs, trn_y);
log_loss(np.array(val_y), rf.predict_proba(val_xs))
0.2200597690773292
Building Tensorflow RandomForest
try: import tensorflow_decision_forests as tfdf
except ModuleNotFoundError:
!pip3 install tensorflow_decision_forests
import tensorflow_decision_forests as tfdf
/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/__init__.py:98: UserWarning: unable to load libtensorflow_io_plugins.so: unable to open file: libtensorflow_io_plugins.so, from paths: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
warnings.warn(f"unable to load libtensorflow_io_plugins.so: {e}")
/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/__init__.py:104: UserWarning: file system plugins are not loaded: unable to open file: libtensorflow_io.so, from paths: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']
warnings.warn(f"file system plugins are not loaded: {e}")
We need to convert the pandas dataframe to tensorflow-keras Dataframe, for the keras model to fit
# Convert the pandas dataframes to tf.data.Datasets that Keras can consume.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(trn_df, label='Class')
valid_ds = tfdf.keras.pd_dataframe_to_tf_dataset(val_df, label='Class')
Warning: Some of the feature names have been changed automatically to be compatible with SavedModels because fix_feature_names=True.
Warning: Some of the feature names have been changed automatically to be compatible with SavedModels because fix_feature_names=True.
We need to define weights for each class, in order to provide to the randomForestModel to calculate the balanced logarithmic loss based on the formula given in evaluation of our problem statement
# Inverse-frequency class weights: each class's weight is
# (1 / class_count) * (total / 2), so the minority class (1) is weighted
# up and the two classes contribute equally to the loss overall.
N_0 = np.sum(df.Class == 0)
N_1 = np.sum(df.Class == 1)
print(f'Classes with label 0: {N_0}, Classes with label 1: {N_1}')

total = N_0 + N_1

weight_0 = (1 / N_0) * (total / 2.0)
weight_1 = (1 / N_1) * (total / 2.0)

class_weight = {0: weight_0, 1: weight_1}

print('Weight for class 0: {:.2f}'.format(weight_0))
print('Weight for class 1: {:.2f}'.format(weight_1))
Classes with label 0: 509, Classes with label 1: 108
Weight for class 0: 0.61
Weight for class 1: 2.86
# Train a TF-DF random forest with the balanced class weights; the
# compiled metrics are only used for evaluation (RF training itself
# has no gradient-based loss).
rf = tfdf.keras.RandomForestModel()
rf.compile(metrics=["accuracy", "binary_crossentropy"])
rf.fit(train_ds, class_weight=class_weight)
Use /tmp/tmphlkyh10a as temporary training directory
Reading training dataset...
Training dataset read in 0:00:06.683765. Found 462 examples.
Training model...
Model trained in 0:00:00.432826
Compiling model...
WARNING: AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7a664bf19c60> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
Model compiled.
[INFO 23-07-06 18:19:45.1798 UTC kernel.cc:1242] Loading model from path /tmp/tmphlkyh10a/model/ with prefix 6e8b67976b1d4515
[INFO 23-07-06 18:19:45.2350 UTC decision_forest.cc:660] Model loaded with 300 root(s), 13538 node(s), and 56 input feature(s).
[INFO 23-07-06 18:19:45.2351 UTC abstract_model.cc:1311] Engine "RandomForestOptPred" built
[INFO 23-07-06 18:19:45.2351 UTC kernel.cc:1074] Use fast generic engine
<keras.callbacks.History at 0x7a66dd4b12d0>
# Predict on the validation set and pull out the evaluation metrics.
val_pred = rf.predict(valid_ds)
evaluation = rf.evaluate(x=valid_ds, return_dict=True)
accuracy = evaluation["accuracy"]
cross_entropy = evaluation["binary_crossentropy"]
1/1 [==============================] - 0s 136ms/step
1/1 [==============================] - 1s 654ms/step - loss: 0.0000e+00 - accuracy: 0.9290 - binary_crossentropy: 0.2167
- Loss on TF RandomForest is less than the original RandomForest. Hence, we will validate the test set on TF RandomForest
Checking on the test set
test_df
Id | AB | AF | AH | AM | AR | AX | AY | AZ | BC | BD | BN | BP | BQ | BR | BZ | CB | CC | CD | CF | CH | CL | CR | CS | CU | CW | DA | DE | DF | DH | DI | DL | DN | DU | DV | DY | EB | EE | EG | EH | EJ | EL | EP | EU | FC | FD | FE | FI | FL | FR | FS | GB | GE | GF | GH | GI | GL | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 00eed32682bb | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | A | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 |
1 | 010ebe33f668 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | A | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 |
2 | 02fa521e1838 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | A | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 |
3 | 040e15f562a2 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | A | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 |
4 | 046e85c7cc7f | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | A | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 |
# Re-read the test set, apply the same 'EJ' encoding as training, and
# convert it to a TF dataset (no label column on the test set).
# NOTE(review): test_df is rebound from a DataFrame to a tf Dataset here.
test_df = pd.read_csv(path/"test.csv")
test_df['EJ'] = pd.Categorical(test_df.EJ)
test_df['EJ'] = test_df['EJ'].cat.codes
test_df = tfdf.keras.pd_dataframe_to_tf_dataset(test_df)
predictions = rf.predict(test_df)
# The model outputs P(class 1); pair it with P(class 0) = 1 - p.
n_predictions = [[round(abs(i - 1), 8), i] for i in predictions.ravel()]
print(n_predictions)
Warning: Some of the feature names have been changed automatically to be compatible with SavedModels because fix_feature_names=True.
1/1 [==============================] - 0s 128ms/step
[[0.6666669, 0.3333331], [0.6666669, 0.3333331], [0.6666669, 0.3333331], [0.6666669, 0.3333331], [0.6666669, 0.3333331]]
Submission
# Fill the sample submission with our [P(class 0), P(class 1)] pairs
# and write the file Kaggle expects.
sample_submission = pd.read_csv(path/"sample_submission.csv")
sample_submission[['class_0', 'class_1']] = n_predictions
sample_submission.to_csv('submission.csv', index=False)
sample_submission.head()
Id | class_0 | class_1 | |
---|---|---|---|
0 | 00eed32682bb | 0.67 | 0.33 |
1 | 010ebe33f668 | 0.67 | 0.33 |
2 | 02fa521e1838 | 0.67 | 0.33 |
3 | 040e15f562a2 | 0.67 | 0.33 |
4 | 046e85c7cc7f | 0.67 | 0.33 |
# Publish the notebook to Kaggle when running locally.
# NOTE(review): `iskaggle` presumably comes from `from fastkaggle import *`;
# confirm it is defined when this branch is reached.
comp = 'icr-identify-age-related-conditions'
if not isKaggleEnv and not iskaggle:
    push_notebook('kashishmukheja', 'ICR Prediction - TF RandomForest',
                  title='ICR Prediction - TF RandomForest',
                  file='K_ICR-Prediction-2.ipynb',
                  competition=comp, private=False, gpu=True)